{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 4: 模型平均的联邦学习\n",
    "\n",
    "**概述:** 在本教程的第2部分中，我们使用了非常简单的联邦学习版本来训练模型。这要求每个数据所有者信任模型所有者才能看到其梯度。\n",
    "\n",
    "**说明:** 在本教程中，我们将展示如何使用第3部分中的高级聚合工具来允许参数由可信的“安全工作机”聚合，然后将最终结果模型发送回模型所有者（我们）。\n",
    "\n",
    "这样，只有安全工作机才能看到谁的模型参数来自谁。我们也许能够知道模型的哪些部分发生了更改，但是我们**不**知道哪个工作人员（Bob或Alice）进行了哪些更改，从而创建了一层隐私。\n",
    "\n",
    "作者:\n",
    " - Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    " - Jason Mancuso - Twitter: [@jvmancuso](https://twitter.com/jvmancuso)\n",
    "\n",
    "中文版译者：\n",
    "- Hou Wei - github：[@dljgs1](https://github.com/dljgs1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import syft as sy\n",
    "import copy\n",
    "hook = sy.TorchHook(torch)\n",
    "from torch import nn, optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第一步： 建立数据所有者\n",
    "\n",
    "首先，我们将创建两个数据所有者（Bob和Alice），每个数据所有者拥有少量数据。 我们还将初始化一个名为“secure_worker”的安全机器。实际上，这可以是安全的硬件（例如英特尔的SGX），也可以只是受信任的中介。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建一对工作机\n",
    "\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "secure_worker = sy.VirtualWorker(hook, id=\"secure_worker\")\n",
    "\n",
    "\n",
    "# 玩具数据集\n",
    "data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)\n",
    "target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)\n",
    "\n",
    "# 通过以下方式获取每个工作机的训练数据的指针\n",
    "# 向bob和alice发送一些训练数据\n",
    "bobs_data = data[0:2].send(bob)\n",
    "bobs_target = target[0:2].send(bob)\n",
    "\n",
    "alices_data = data[2:].send(alice)\n",
    "alices_target = target[2:].send(alice)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第二步: 建立我们的模型\n",
    "\n",
    "对于此示例，我们将使用简单的线性模型进行训练。 我们通常可以使用PyTorch的nn.Linear构造函数对其进行初始化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化玩具模型\n",
    "model = nn.Linear(2,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第三步：发送模型的拷贝给Alice和Bob\n",
    "\n",
    "接下来，我们需要将当前模型的副本发送给Alice和Bob，以便他们可以对自己的数据集执行学习步骤。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bobs_model = model.copy().send(bob)\n",
    "alices_model = model.copy().send(alice)\n",
    "\n",
    "bobs_opt = optim.SGD(params=bobs_model.parameters(),lr=0.1)\n",
    "alices_opt = optim.SGD(params=alices_model.parameters(),lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第4步：训练鲍勃和爱丽丝的模型（并行）\n",
    "\n",
    "与通过安全平均进行联邦学习的常规做法一样，每个数据所有者首先在本地对模型进行几次迭代训练，然后再对模型进行平均。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "\n",
    "    # 训练Bob的模型\n",
    "    bobs_opt.zero_grad()\n",
    "    bobs_pred = bobs_model(bobs_data)\n",
    "    bobs_loss = ((bobs_pred - bobs_target)**2).sum()\n",
    "    bobs_loss.backward()\n",
    "\n",
    "    bobs_opt.step()\n",
    "    bobs_loss = bobs_loss.get().data\n",
    "\n",
    "    # 训练Alice的模型\n",
    "    alices_opt.zero_grad()\n",
    "    alices_pred = alices_model(alices_data)\n",
    "    alices_loss = ((alices_pred - alices_target)**2).sum()\n",
    "    alices_loss.backward()\n",
    "\n",
    "    alices_opt.step()\n",
    "    alices_loss = alices_loss.get().data\n",
    "    \n",
    "    print(\"Bob:\" + str(bobs_loss) + \" Alice:\" + str(alices_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第5步：将两个更新的模型发送到安全工作机\n",
    "\n",
    "现在，每个数据所有者都拥有部分受过训练的模型，是时候以安全的方式将它们平均在一起了。我们通过指示Alice和Bob将其模型发送到安全（可信）服务器来实现这一目标。\n",
    "\n",
    "请注意，这种使用我们的API的方式意味着每个模型都**直接**发送到secure_worker。我们从未见过。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alices_model.move(secure_worker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bobs_model.move(secure_worker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第6步：模型平均"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，此训练epoch（译者注：一个epoch表示全部训练数据完整训练一轮）的最后一步是将Bob和Alice的训练模型平均在一起，然后使用它来设置全局“模型”的值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())\n",
    "    model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 冲洗并重复\n",
    "\n",
    "现在，我们只需要对此进行多次迭代！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 10\n",
    "worker_iters = 5\n",
    "\n",
    "for a_iter in range(iterations):\n",
    "    \n",
    "    bobs_model = model.copy().send(bob)\n",
    "    alices_model = model.copy().send(alice)\n",
    "\n",
    "    bobs_opt = optim.SGD(params=bobs_model.parameters(),lr=0.1)\n",
    "    alices_opt = optim.SGD(params=alices_model.parameters(),lr=0.1)\n",
    "\n",
    "    for wi in range(worker_iters):\n",
    "\n",
    "        # 训练Bob的模型\n",
    "        bobs_opt.zero_grad()\n",
    "        bobs_pred = bobs_model(bobs_data)\n",
    "        bobs_loss = ((bobs_pred - bobs_target)**2).sum()\n",
    "        bobs_loss.backward()\n",
    "\n",
    "        bobs_opt.step()\n",
    "        bobs_loss = bobs_loss.get().data\n",
    "\n",
    "        # 训练Alice的模型\n",
    "        alices_opt.zero_grad()\n",
    "        alices_pred = alices_model(alices_data)\n",
    "        alices_loss = ((alices_pred - alices_target)**2).sum()\n",
    "        alices_loss.backward()\n",
    "\n",
    "        alices_opt.step()\n",
    "        alices_loss = alices_loss.get().data\n",
    "    \n",
    "    alices_model.move(secure_worker)\n",
    "    bobs_model.move(secure_worker)\n",
    "    with torch.no_grad():\n",
    "        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())\n",
    "        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())\n",
    "    \n",
    "    print(\"Bob:\" + str(bobs_loss) + \" Alice:\" + str(alices_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，我们想确保我们得到的模型学习正确，因此我们将在测试数据集上对其进行评估。在这个玩具问题中，我们将使用原始数据，但在实践中，我们将希望使用新数据来了解模型对看不见的样本的泛化程度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = model(data)\n",
    "loss = ((preds - target) ** 2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(preds)\n",
    "print(target)\n",
    "print(loss.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这个玩具示例中，平均模型相对于本地训练的纯文本模型表现不佳，但是我们能够在不暴露每个工人的训练数据的情况下对其进行训练。我们还能够在可信任的聚合器上聚合每个工作人员的更新模型，以防止数据泄露给模型所有者。\n",
    "\n",
    "在未来的教程中，我们的目标是直接使用梯度进行可信聚合，以便我们可以使用更好的梯度估计来更新模型并获得更强大的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 恭喜!!! 是时候加入社区了!\n",
    "\n",
    "祝贺您完成本笔记本教程！ 如果您喜欢此方法，并希望加入保护隐私、去中心化AI和AI供应链（数据）所有权的运动，则可以通过以下方式做到这一点！\n",
    "\n",
    "### 给 PySyft 加星\n",
    "\n",
    "帮助我们的社区的最简单方法是仅通过给GitHub存储库加注星标！ 这有助于提高人们对我们正在构建的出色工具的认识。\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### 选择我们的教程\n",
    "\n",
    "我们编写了非常不错的教程，以更好地了解联合学习和隐私保护学习的外观，以及我们如何为实现这一目标添砖加瓦。\n",
    "\n",
    "- [Checkout the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### 加入我们的 Slack!\n",
    "\n",
    "保持最新进展的最佳方法是加入我们的社区！ 您可以通过填写以下表格来做到这一点[http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### 加入代码项目!\n",
    "\n",
    "对我们的社区做出贡献的最好方法是成为代码贡献者！ 您随时可以转到PySyft GitHub的Issue页面并过滤“projects”。这将向您显示所有概述，选择您可以加入的项目！如果您不想加入项目，但是想做一些编码，则还可以通过搜索标记为“good first issue”的GitHub问题来寻找更多的“一次性”微型项目。\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### 捐赠\n",
    "\n",
    "如果您没有时间为我们的代码库做贡献，但仍想提供支持，那么您也可以成为Open Collective的支持者。所有捐款都将用于我们的网络托管和其他社区支出，例如黑客马拉松和聚会！\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
